{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fb0af604",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "34818543",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([1, 2, 3]), torch.Size([1, 1, 3]))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Single-layer, unidirectional RNN: 4 input features -> 3 hidden units.\n",
    "single_rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=1, batch_first=True)\n",
    "# One sequence of length 2; with batch_first=True the layout is (batch, seq, feature).\n",
    "input = torch.randn(1, 2, 4)\n",
    "output, h_n = single_rnn(input)\n",
    "# output holds the hidden state of every step, h_n the state after the last step.\n",
    "(output.shape, h_n.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "42f5382c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([1, 2, 6]), torch.Size([2, 1, 3]))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same layer sizes as before, but run in both time directions.\n",
    "torch_rnn = nn.RNN(\n",
    "    input_size=4,\n",
    "    hidden_size=3,\n",
    "    num_layers=1,\n",
    "    batch_first=True,\n",
    "    bidirectional=True,\n",
    ")\n",
    "input = torch.randn(1, 2, 4)\n",
    "output, h_n = torch_rnn(input)\n",
    "# output concatenates both directions on the last axis; h_n stacks one final state per direction.\n",
    "(output.shape, h_n.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc3798ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed before any random draw so that `input` (and every later random\n",
    "# initialisation) is governed by SEED rather than by whatever ran earlier.\n",
    "SEED = 20250630\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)\n",
    "\n",
    "batch_size, seq_len = 2, 3\n",
    "# feature sizes\n",
    "input_size, hidden_size = 2, 3\n",
    "\n",
    "input = torch.randn(batch_size, seq_len, input_size)\n",
    "# Zero initial hidden state with a leading axis of size 1, as passed to nn.RNN below.\n",
    "h_prev_back = torch.zeros((1, batch_size, hidden_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a49d9166",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 3, 3]) torch.Size([1, 2, 3])\n",
      "---Parameters---\n",
      "weight_ih_l0 torch.Size([3, 2])\n",
      "weight_hh_l0 torch.Size([3, 3])\n",
      "bias_ih_l0 torch.Size([3])\n",
      "bias_hh_l0 torch.Size([3])\n",
      "---Output---\n",
      "tensor([[[ 0.7200,  0.2839,  0.1090],\n",
      "         [ 0.8598,  0.5006, -0.0296],\n",
      "         [ 0.9018,  0.3565,  0.0851]],\n",
      "\n",
      "        [[ 0.7355, -0.1017, -0.2827],\n",
      "         [ 0.8609,  0.5447, -0.2269],\n",
      "         [ 0.9237,  0.3352,  0.6465]]], grad_fn=<TransposeBackward1>)\n",
      "tensor([[[0.9018, 0.3565, 0.0851],\n",
      "         [0.9237, 0.3352, 0.6465]]], grad_fn=<StackBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# Reference implementation: PyTorch's built-in single-layer RNN.\n",
    "rnn_real = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)\n",
    "out_real, h_real = rnn_real(input, hx=h_prev_back)\n",
    "print(out_real.shape, h_real.shape)\n",
    "\n",
    "# Parameter names and shapes, for matching against the hand-written cell.\n",
    "print(\"---Parameters---\")\n",
    "for param_name, param in rnn_real.named_parameters():\n",
    "    print(param_name, param.shape)\n",
    "\n",
    "print(\"---Output---\")\n",
    "print(out_real, h_real, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0572a53",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import Tensor\n",
    "\n",
    "\n",
    "class RnnCell(nn.Module):\n",
    "    \"\"\"One Elman RNN step.\n",
    "\n",
    "    Computes ``activ(x_t @ w_ih.T + b_ih + h_prev @ w_hh.T + b_hh)``.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of features in ``x_t``.\n",
    "        hidden_size: number of features in the hidden state.\n",
    "        activ: module applied to the pre-activation; ``nn.Tanh()`` by default.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, activ=nn.Tanh()):\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.activ = activ\n",
    "        # Storage only; every parameter receives its values in init_weights().\n",
    "        self.w_ih = nn.Parameter(torch.empty(hidden_size, input_size))\n",
    "        self.w_hh = nn.Parameter(torch.empty(hidden_size, hidden_size))\n",
    "        self.b_ih = nn.Parameter(torch.empty(hidden_size))\n",
    "        self.b_hh = nn.Parameter(torch.empty(hidden_size))\n",
    "        self.init_weights()\n",
    "\n",
    "    def init_weights(self):\n",
    "        \"\"\"Xavier-normal weights; both biases set to the constant 0.1.\"\"\"\n",
    "        nn.init.xavier_normal_(self.w_ih)\n",
    "        nn.init.xavier_normal_(self.w_hh)\n",
    "        nn.init.constant_(self.b_ih, 0.1)\n",
    "        # b_hh used to be skipped here and kept its construction-time values.\n",
    "        nn.init.constant_(self.b_hh, 0.1)\n",
    "\n",
    "    def forward(self, x_t: Tensor, h_prev: Tensor) -> Tensor:\n",
    "        \"\"\"Advance the hidden state by one time step.\n",
    "\n",
    "        Args:\n",
    "            x_t: input for this step; last dimension is ``input_size``.\n",
    "            h_prev: previous hidden state; last dimension is ``hidden_size``.\n",
    "\n",
    "        Returns:\n",
    "            The new hidden state, last dimension ``hidden_size``.\n",
    "        \"\"\"\n",
    "        pre_activation = (\n",
    "            x_t @ self.w_ih.T + self.b_ih + h_prev @ self.w_hh.T + self.b_hh\n",
    "        )\n",
    "        return self.activ(pre_activation)\n",
    "\n",
    "\n",
    "class BiRNN(nn.Module):\n",
    "    \"\"\"Bidirectional single-layer RNN built from two step cells.\n",
    "\n",
    "    ``rnn1`` walks the sequence from the first step to the last and ``rnn2``\n",
    "    from the last step to the first. The per-step hidden states of both\n",
    "    directions are concatenated on the feature axis.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of features per time step.\n",
    "        hidden_size: hidden-state width of each direction.\n",
    "        output_size: output width of ``fc``. ``fc`` and ``leaky_relu`` are\n",
    "            created here but are not applied by ``forward``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size):\n",
    "        super().__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.rnn1 = RnnCell(input_size, hidden_size, activ=nn.Tanh())\n",
    "        self.rnn2 = RnnCell(input_size, hidden_size, activ=nn.Tanh())\n",
    "\n",
    "        self.fc = nn.Linear(hidden_size * 2, output_size)\n",
    "        self.leaky_relu = nn.LeakyReLU()\n",
    "\n",
    "    def _initial_state(self, x: Tensor, h_prev):\n",
    "        \"\"\"Return a (batch, hidden_size) start state on the same device as ``x``.\"\"\"\n",
    "        if h_prev is None:\n",
    "            return x.new_zeros(x.shape[0], self.hidden_size)\n",
    "        # Accept a state with a leading axis of size 1, i.e. (1, batch, hidden).\n",
    "        if h_prev.dim() == 3 and h_prev.shape[0] == 1:\n",
    "            h_prev = h_prev.squeeze(0)\n",
    "        return h_prev.to(x.device)\n",
    "\n",
    "    def _scan(self, cell, x: Tensor, h_prev, steps):\n",
    "        \"\"\"Apply ``cell`` at the time indices in ``steps``, storing each state at its index.\"\"\"\n",
    "        h_prev = self._initial_state(x, h_prev)\n",
    "        # Allocate from x so the buffer lives on x's device.\n",
    "        h_out = x.new_zeros(x.shape[0], x.shape[1], self.hidden_size)\n",
    "        for i in steps:\n",
    "            h_prev = cell(x[:, i, :], h_prev)\n",
    "            h_out[:, i, :] = h_prev\n",
    "        return h_out, h_prev\n",
    "\n",
    "    def forward1(self, x: Tensor, h_prev_forward=None):\n",
    "        \"\"\"Forward-direction RNN: steps 0 .. seq_len-1 of ``x`` through ``rnn1``.\n",
    "\n",
    "        Returns:\n",
    "            ``(h_out, h_last)`` with ``h_out`` of shape (batch, seq_len, hidden_size)\n",
    "            and ``h_last`` the state after the final processed step.\n",
    "        \"\"\"\n",
    "        # The sequence length comes from x itself, not from a notebook global.\n",
    "        return self._scan(self.rnn1, x, h_prev_forward, range(x.shape[1]))\n",
    "\n",
    "    def forward2(self, x: Tensor, h_prev_back=None):\n",
    "        \"\"\"Backward-direction RNN: steps seq_len-1 .. 0 of ``x`` through ``rnn2``.\n",
    "\n",
    "        Returns:\n",
    "            ``(h_out, h_last)`` with ``h_out`` of shape (batch, seq_len, hidden_size),\n",
    "            indexed by original time step, and ``h_last`` the state after step 0.\n",
    "        \"\"\"\n",
    "        return self._scan(self.rnn2, x, h_prev_back, reversed(range(x.shape[1])))\n",
    "\n",
    "    def forward(self, x: Tensor, h_prev_forward=None, h_prev_back=None):\n",
    "        \"\"\"Run both directions over ``x`` of shape (batch, seq_len, input_size).\n",
    "\n",
    "        Args:\n",
    "            x: input batch.\n",
    "            h_prev_forward: optional start state of the forward direction.\n",
    "            h_prev_back: optional start state of the backward direction.\n",
    "                Missing states default to zeros.\n",
    "\n",
    "        Returns:\n",
    "            ``(h_out, h_final)``: ``h_out`` is (batch, seq_len, hidden_size * 2),\n",
    "            forward features first; ``h_final`` is (2, batch, hidden_size),\n",
    "            forward state first.\n",
    "        \"\"\"\n",
    "        h_out_forw, h_last_forw = self.forward1(x, h_prev_forward)\n",
    "        h_out_back, h_last_back = self.forward2(x, h_prev_back)\n",
    "        # shape: (batch_size, seq_len, hidden_size * 2)\n",
    "        h_out = torch.cat((h_out_forw, h_out_back), dim=-1)\n",
    "        # shape: (2, batch_size, hidden_size)\n",
    "        h_final = torch.stack([h_last_forw, h_last_back], dim=0)\n",
    "        return h_out, h_final"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99d20bab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 2, 3]) torch.Size([2, 3, 3])\n"
     ]
    }
   ],
   "source": [
    "my_rnn = BiRNN(input_size, hidden_size, hidden_size)\n",
    "# h_prev_back carries a leading axis of size 1 for nn.RNN; BiRNN stacks\n",
    "# per-direction states of shape (batch, hidden), so drop that axis here.\n",
    "out, h = my_rnn(input, h_prev_forward=h_prev_back.squeeze(0))\n",
    "# h: (2, batch, hidden); out: (batch, seq_len, 2 * hidden)\n",
    "print(h.shape, out.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c6d18da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 0.7200,  0.2839,  0.1090],\n",
      "         [ 0.8598,  0.5006, -0.0296],\n",
      "         [ 0.9018,  0.3565,  0.0851]],\n",
      "\n",
      "        [[ 0.7355, -0.1017, -0.2827],\n",
      "         [ 0.8609,  0.5447, -0.2269],\n",
      "         [ 0.9237,  0.3352,  0.6465]]], grad_fn=<CopySlices>)\n",
      "tensor([[[0.9018, 0.3565, 0.0851],\n",
      "         [0.9237, 0.3352, 0.6465]]], grad_fn=<TanhBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# nn.RNN parameters to transplant (name: shape):\n",
    "#   weight_ih_l0: (3, 2)   weight_hh_l0: (3, 3)\n",
    "#   bias_ih_l0:   (3,)     bias_hh_l0:   (3,)\n",
    "\n",
    "my_rnn = BiRNN(input_size, hidden_size, hidden_size)\n",
    "# Give both directions the reference weights so that each half of the\n",
    "# bidirectional output can be checked against rnn_real.\n",
    "for cell in (my_rnn.rnn1, my_rnn.rnn2):\n",
    "    cell.w_ih = rnn_real.weight_ih_l0\n",
    "    cell.w_hh = rnn_real.weight_hh_l0\n",
    "    cell.b_ih = rnn_real.bias_ih_l0\n",
    "    cell.b_hh = rnn_real.bias_hh_l0\n",
    "\n",
    "# BiRNN stacks (batch, hidden) states; h_prev_back is (1, batch, hidden).\n",
    "out_bi, h_bi = my_rnn(input, h_prev_forward=h_prev_back.squeeze(0))\n",
    "\n",
    "# Forward direction: first hidden_size output features and first final state.\n",
    "# These are what the unidirectional rnn_real must reproduce.\n",
    "out = out_bi[..., :hidden_size]\n",
    "h = h_bi[:1]\n",
    "\n",
    "# Backward direction: equals rnn_real run over the time-reversed sequence,\n",
    "# flipped back so that index i again refers to original step i.\n",
    "out_rev_real, h_rev_real = rnn_real(input.flip(1), h_prev_back)\n",
    "assert torch.allclose(out_bi[..., hidden_size:], out_rev_real.flip(1), atol=1e-6)\n",
    "assert torch.allclose(h_bi[1:], h_rev_real, atol=1e-6)\n",
    "\n",
    "print(out)\n",
    "print(h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4975b0ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# assert_close compares shapes and dtypes as well as values (allclose would\n",
    "# silently broadcast) and reports the largest mismatch when it fails.\n",
    "torch.testing.assert_close(out, out_real)\n",
    "torch.testing.assert_close(h, h_real)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
